import os
import sys
import pickle

from transformers import AutoTokenizer
from tokenizers import AddedToken


# Fail with a usage message rather than a bare IndexError when the
# required positional arguments are missing.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} TOKEN_DICT_DIR TOKENIZER_NAME")

# Directory containing the per-language token_dict_<lang>.pkl files.
dirpath = sys.argv[1]
# Name of the base tokenizer to extend (validated further down in this script).
tokenizer_name = sys.argv[2]


# Languages whose token dictionaries are loaded and merged.
LANGS = ["python", "java", "c", "c_sharp", "php", "go", "javascript", "ruby"]
# Prefix / suffix that mark structural node tokens; keep_token keeps any
# token starting with NODE_START_TOK or ending with NODE_END_TOK.
NODE_START_TOK = "(_."
NODE_END_TOK = "._)"
# Additional pseudo-tokens kept verbatim by keep_token.
SPECIAL_TOKENS = ["indent", "dedent", "newline", "line_break"]


def load_files():
    """Load and merge the per-language token dictionaries from ``dirpath``.

    For each language in ``LANGS`` the file ``token_dict_<lang>.pkl`` is
    unpickled and merged into a single dict (later files overwrite earlier
    keys). The size of each entry's ``"nonbpe"`` collection is printed.

    NOTE(review): ``pickle.load`` executes arbitrary code from the file, so
    ``dirpath`` must only point at trusted, locally produced artifacts.

    Returns:
        The merged dict; presumably keyed by language name — confirm
        against whatever produces the pickles.
    """
    merged = {}

    for language in LANGS:
        pkl_path = os.path.join(dirpath, f"token_dict_{language}.pkl")
        with open(pkl_path, "rb") as handle:
            merged.update(pickle.load(handle))

    for key, entry in merged.items():
        print(key, len(entry["nonbpe"]))

    return merged


def keep_token(token):
    """Decide whether ``token`` should be added to the tokenizer.

    Surrounding whitespace is ignored. A token is kept when it starts with
    ``NODE_START_TOK``, ends with ``NODE_END_TOK``, or is one of
    ``SPECIAL_TOKENS``.

    Returns:
        ``True`` to keep the token, ``False`` otherwise.
    """
    bare = token.strip()
    is_node_marker = bare.startswith(NODE_START_TOK) or bare.endswith(NODE_END_TOK)
    return is_node_marker or bare in SPECIAL_TOKENS


def add_tokens(tokenizer, tokendict):
    """Add every language's kept non-BPE tokens to ``tokenizer``.

    The ``"nonbpe"`` tokens of each language in ``LANGS`` are filtered with
    ``keep_token``, de-duplicated, sorted, and each is added as an
    ``AddedToken`` whose content is the token prefixed by one space. A
    token is skipped when its stripped content is already in the
    tokenizer's vocabulary, or when ``tokenizer.add_tokens`` reports that
    it added nothing.

    Args:
        tokenizer: Tokenizer to extend; it is modified in place.
        tokendict: Mapping from language name to a dict holding a
            ``"nonbpe"`` iterable of token strings (as returned by
            ``load_files``).

    Returns:
        ``(n_added, skipped_toks, added_toks)``: the growth of
        ``len(tokenizer)``, the ``AddedToken`` objects that were skipped,
        and those that were added.
    """
    all_nonbpe_toks = set()
    for lang in LANGS:
        all_nonbpe_toks.update(tok for tok in tokendict[lang]["nonbpe"] if keep_token(tok))

    # Sorted so tokens are added (and assigned ids) in a deterministic order.
    nonbpe_toks = [AddedToken(content=f" {txt}") for txt in sorted(all_nonbpe_toks)]
    print(f"No. of non-BPE tokens adding: {len(nonbpe_toks)}")

    n_old = len(tokenizer)

    # get_vocab() builds a new dict of the whole vocabulary on each call, so
    # take a single snapshot instead of rebuilding it for every candidate.
    # The snapshot cannot go stale for this check: every token added below
    # has content beginning with " ", while the lookup key is stripped.
    vocab = tokenizer.get_vocab()

    # add new tokens to tokenizer, one at a time so each can be classified
    skipped_toks, added_toks = [], []
    for tok in nonbpe_toks:
        n_added_tok = 0
        if tok.content.strip() not in vocab:
            n_added_tok = tokenizer.add_tokens(new_tokens=tok)

        if n_added_tok == 0:
            skipped_toks.append(tok)
        else:
            added_toks.append(tok)

    n_added = len(tokenizer) - n_old
    return n_added, skipped_toks, added_toks


##------------------------------------------------------------------------##

# Supported tokenizer names (argv[2]) mapped to their pretrained checkpoints.
TOKENIZER_CHECKPOINTS = {"codet5": "Salesforce/codet5-base"}

# Validate the name before the comparatively expensive load of all pickles.
if tokenizer_name not in TOKENIZER_CHECKPOINTS:
    raise ValueError(f"tokenizer name not supported: {tokenizer_name!r}")

tokendict = load_files()

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINTS[tokenizer_name])

n_added, skipped_toks, added_toks = add_tokens(tokenizer, tokendict)
print(
    f"Added {n_added} tokens ({len(added_toks)} accepted, "
    f"{len(skipped_toks)} skipped); vocab size now {len(tokenizer)}"
)

# NOTE: relative path, resolved against the current working directory
# (not this script's location).
outpath = f"../artifacts/tokenizer/{tokenizer_name}/tokenizer"
tokenizer.save_pretrained(outpath)
